Intermediate checkpointing for sequential calibration #1152
Conversation
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
📝 Walkthrough

Adds sequential calibration checkpointing: new config options, a checkpoint registry and saver selection, model metadata persistence for resume, collector state-machine changes to support resuming and warm-up, a HuggingFace saver plugin, and unit tests validating save/resume behavior.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant SequentialCalib as sequential_calibrate
    participant Collector as LayerActivationCollector
    participant CheckpointUtils
    participant Saver as ModelCheckpointSaver
    participant Model
    User->>SequentialCalib: start(calib_func, checkpoint_dir, interval)
    SequentialCalib->>CheckpointUtils: detect_sequential_resume_layer(Model, num_layers)
    CheckpointUtils->>Model: getattr(_seq_calib_progress)
    Model-->>CheckpointUtils: progress or None
    CheckpointUtils-->>SequentialCalib: resume_idx, metadata
    alt resume available
        SequentialCalib->>Collector: prepare_for_resume(resume_idx, forward_loop)
        Collector->>Collector: _run_warmup_capture / set modes
    end
    loop per-layer from resume_idx
        SequentialCalib->>Collector: _set_layer_states(layer_idx)
        Collector->>Model: forward(pass)
        SequentialCalib->>CheckpointUtils: should_save_seq_calib_checkpoint(layer_idx, ...)
        alt should save
            SequentialCalib->>Collector: get_layer_output_metas(up_to_idx)
            Collector-->>SequentialCalib: metas
            SequentialCalib->>CheckpointUtils: save_sequential_checkpoint(Model, layer_idx, total, checkpoint_dir, metas)
            CheckpointUtils->>CheckpointUtils: get_checkpoint_saver(Model)
            CheckpointUtils->>Model: setattr(_seq_calib_progress, payload)
            CheckpointUtils->>Saver: save_fn(Model, checkpoint_dir)
            Saver->>Model: persist checkpoint
        end
    end
    SequentialCalib->>Model: delattr(_seq_calib_progress) (cleanup)
```
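The save/resume loop in the diagram can be sketched in plain Python. This is a minimal illustration of the checkpoint-interval logic only; `should_save` and `run_sequential_calibration` are hypothetical stand-ins, not the PR's actual helpers:

```python
def should_save(layer_idx: int, interval: int, total_layers: int) -> bool:
    """Save every `interval` layers, and always after the final layer."""
    if interval <= 0:
        raise ValueError(f"checkpoint_interval must be positive, got {interval}")
    return (layer_idx + 1) % interval == 0 or layer_idx == total_layers - 1


def run_sequential_calibration(total_layers: int, interval: int, resume_from: int = 0) -> list[int]:
    """Calibrate layers [resume_from, total_layers) and return the indices checkpointed."""
    saved = []
    for layer_idx in range(resume_from, total_layers):
        # ... calibrate layer `layer_idx` against captured activations here ...
        if should_save(layer_idx, interval, total_layers):
            saved.append(layer_idx)  # persist progress metadata + weights via the saver
    return saved
```

For example, with 10 layers and an interval of 4, checkpoints land after layers 3, 7, and 9; resuming from layer 4 checkpoints only at 7 and 9.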
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##              main    #1152       +/-   ##
===========================================
- Coverage    70.18%   54.64%   -15.54%
===========================================
  Files          230      349      +119
  Lines        26080    39895    +13815
===========================================
+ Hits         18304    21802     +3498
- Misses        7776    18093    +10317
```

Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry.
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/conversion.py`:
- Around line 143-145: The seq_calib_progress metadata is restored after the
fast-path return so the resume marker is lost; modify the fast-path in
convert_quantized_model (or the function containing the quantizer_state
fast-path) to set setattr(model, SEQ_CALIB_PROGRESS_ATTR,
metadata["seq_calib_progress"]) (guarded by "seq_calib_progress" in metadata)
before any early return that returns quantizer_state via extra_state, ensuring
the metadata restoration happens regardless of taking the fast-path.
In `@modelopt/torch/quantization/utils/activation_collector.py`:
- Around line 334-345: get_layer_output_metas currently returns the stored
output_meta objects verbatim (from _extract_output_meta) which may include
torch.device info and cause cross-device errors on resume; update
get_layer_output_metas (and the complementary loading path in _seq_calib) to
produce device-agnostic serializable metas: either strip/convert any
torch.device fields to a neutral representation (e.g. store device as string
like "cpu") or move tensors to a canonical device (cpu) before returning, and
ensure the loader remaps that neutral/device-string back to the current runtime
device when rebuilding state; look for references to _decoder_layers,
_LAYER_ATTR, state.output_meta, and _seq_calib to implement symmetric save (in
get_layer_output_metas) and load remapping so forward passes never receive
tensors bound to a stale device.
In `@modelopt/torch/quantization/utils/checkpoint.py`:
- Around line 104-113: The function should_save_seq_calib_checkpoint currently
does a modulo with checkpoint_interval which raises for zero and misbehaves for
negatives; add an upfront guard in should_save_seq_calib_checkpoint to reject
non-positive intervals by checking if checkpoint_interval is not None and
checkpoint_interval > 0 and raise a ValueError (with a clear message referencing
checkpoint_interval) before performing the modulo, so the later logic that uses
(layer_idx + 1) % checkpoint_interval and the other checks can remain unchanged.
- Around line 83-101: The persisted progress payload (stored under
SEQ_CALIB_PROGRESS_ATTR and loaded into progress) must be validated before use:
ensure progress is a dict and contains integer keys "completed_layer_idx" and
"total_layers", then verify completed_layer_idx is within [-1, num_layers - 1]
and that total_layers equals num_layers; if any check fails, raise a clear
ValueError (or return 0, None) instead of proceeding. Update the logic around
the progress variable, completed_layer and saved_total to validate types and
ranges before computing resume_from, printing via print_rank_0, or returning
layer_output_metas.
In `@tests/unit/torch/quantization/test_sequential_calibrate.py`:
- Around line 947-964: Replace the duplicated logic in
test_update_quantize_metadata_includes_progress with a real call to
update_quantize_metadata: set the SEQ_CALIB_PROGRESS_ATTR on the model as you
do, create an empty metadata dict (or config expected by
update_quantize_metadata), call update_quantize_metadata(model, metadata) (or
the correct signature of update_quantize_metadata) so it picks up
SEQ_CALIB_PROGRESS_ATTR, then assert metadata["seq_calib_progress"] equals the
progress value and finally delattr the SEQ_CALIB_PROGRESS_ATTR; reference the
test name test_update_quantize_metadata_includes_progress, the function
update_quantize_metadata, and the attribute SEQ_CALIB_PROGRESS_ATTR when
locating the code to change.
- Around line 984-987: Add an inline comment next to the torch.load call
explaining why using weights_only=False is safe: note that the buffer `buf` is
produced locally by `torch.save(progress, buf)` (not from external input) so
deserializing with `weights_only=False` does not violate the security guideline;
annotate the `loaded = torch.load(buf, weights_only=False)` line with this
justification referencing `buf`, `progress`, and `torch.load`.
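Of the comments above, the device-bound output metas in `get_layer_output_metas` is the one without a suggested diff elsewhere in this review. A symmetric save/load remapping could look like this sketch; the helper names `to_serializable_meta` / `restore_meta` and the `"device"`-key convention are assumptions, not the PR's actual code:

```python
import torch


def to_serializable_meta(meta: dict) -> dict:
    """Make a meta dict device-agnostic before persisting: torch.device values
    become strings and tensors are moved to CPU."""
    out = {}
    for key, value in meta.items():
        if isinstance(value, torch.device):
            out[key] = str(value)  # e.g. torch.device("cuda:0") -> "cuda:0"
        elif isinstance(value, torch.Tensor):
            out[key] = value.cpu()
        else:
            out[key] = value
    return out


def restore_meta(meta: dict, device: torch.device) -> dict:
    """Remap a saved meta dict onto the current runtime device on resume.
    By convention, the stale device string lives under the "device" key."""
    out = {}
    for key, value in meta.items():
        if key == "device":
            out[key] = device  # ignore the saved string; bind to the live device
        elif isinstance(value, torch.Tensor):
            out[key] = value.to(device)
        else:
            out[key] = value
    return out
```

With this shape, a forward pass on resume never receives tensors bound to a stale device, because every tensor is rebound via `.to(device)` at load time.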
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7693e81d-61c3-4338-9c0f-67d1b730e51b
📒 Files selected for processing (8)

- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/conversion.py
- modelopt/torch/quantization/mode.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/utils/activation_collector.py
- modelopt/torch/quantization/utils/checkpoint.py
- tests/unit/torch/quantization/test_sequential_calibrate.py
```python
progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
if progress is None:
    return 0, None

completed_layer = progress["completed_layer_idx"]
saved_total = progress["total_layers"]

if saved_total != num_layers:
    raise ValueError(
        f"Checkpoint was saved with {saved_total} layers but model has "
        f"{num_layers} layers. Cannot resume."
    )

resume_from = completed_layer + 1
print_rank_0(
    f"Resuming sequential calibration from layer {resume_from} "
    f"(layers 0..{completed_layer} already calibrated)"
)
return resume_from, progress.get("layer_output_metas", {})
```
Validate the persisted progress payload before using it.
This metadata comes back from disk, so malformed values currently turn into KeyError or an impossible resume point later in the flow. Please validate the schema and enforce completed_layer_idx within [-1, num_layers - 1] before logging or returning it.
🛠️ Suggested guard

```diff
 def detect_sequential_resume_layer(model: nn.Module, num_layers: int) -> tuple[int, dict | None]:
     """Read checkpoint progress from the model and return ``(resume_layer_idx, layer_output_metas)``.
 @@
     progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
     if progress is None:
         return 0, None
-
-    completed_layer = progress["completed_layer_idx"]
-    saved_total = progress["total_layers"]
+    if not isinstance(progress, dict):
+        raise ValueError("Malformed sequential calibration checkpoint metadata.")
+    try:
+        completed_layer = int(progress["completed_layer_idx"])
+        saved_total = int(progress["total_layers"])
+    except (KeyError, TypeError, ValueError) as exc:
+        raise ValueError("Malformed sequential calibration checkpoint metadata.") from exc
     if saved_total != num_layers:
         raise ValueError(
             f"Checkpoint was saved with {saved_total} layers but model has "
             f"{num_layers} layers. Cannot resume."
         )
+    if completed_layer < -1 or completed_layer >= num_layers:
+        raise ValueError(
+            f"Checkpoint completed_layer_idx={completed_layer} is out of range "
+            f"for a model with {num_layers} layers."
+        )
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
if progress is None:
    return 0, None
if not isinstance(progress, dict):
    raise ValueError("Malformed sequential calibration checkpoint metadata.")
try:
    completed_layer = int(progress["completed_layer_idx"])
    saved_total = int(progress["total_layers"])
except (KeyError, TypeError, ValueError) as exc:
    raise ValueError("Malformed sequential calibration checkpoint metadata.") from exc
if saved_total != num_layers:
    raise ValueError(
        f"Checkpoint was saved with {saved_total} layers but model has "
        f"{num_layers} layers. Cannot resume."
    )
if completed_layer < -1 or completed_layer >= num_layers:
    raise ValueError(
        f"Checkpoint completed_layer_idx={completed_layer} is out of range "
        f"for a model with {num_layers} layers."
    )
resume_from = completed_layer + 1
print_rank_0(
    f"Resuming sequential calibration from layer {resume_from} "
    f"(layers 0..{completed_layer} already calibrated)"
)
return resume_from, progress.get("layer_output_metas", {})
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/utils/checkpoint.py` around lines 83 - 101, The
persisted progress payload (stored under SEQ_CALIB_PROGRESS_ATTR and loaded into
progress) must be validated before use: ensure progress is a dict and contains
integer keys "completed_layer_idx" and "total_layers", then verify
completed_layer_idx is within [-1, num_layers - 1] and that total_layers equals
num_layers; if any check fails, raise a clear ValueError (or return 0, None)
instead of proceeding. Update the logic around the progress variable,
completed_layer and saved_total to validate types and ranges before computing
resume_from, printing via print_rank_0, or returning layer_output_metas.
```python
buf = io.BytesIO()
torch.save(progress, buf)
buf.seek(0)
loaded = torch.load(buf, weights_only=False)
```
🧩 Analysis chain

🏁 Script executed:

```shell
# First, check if the file exists
git ls-files tests/unit/torch/quantization/test_sequential_calibrate.py
```

Repository: NVIDIA/Model-Optimizer. Length of output: 124

🏁 Script executed:

```shell
# Read the specified lines and surrounding context
sed -n '980,995p' tests/unit/torch/quantization/test_sequential_calibrate.py
```

Repository: NVIDIA/Model-Optimizer. Length of output: 564
Add inline comment to document why weights_only=False is safe.
Line 987 requires an inline comment justifying weights_only=False per security guidelines. The buffer is locally generated via torch.save() and never comes from external input, which satisfies the exception criteria.
Suggested fix

```diff
 buf = io.BytesIO()
 torch.save(progress, buf)
 buf.seek(0)
-loaded = torch.load(buf, weights_only=False)
+# Safe here: `buf` is produced by `torch.save` in this test and never comes from user input.
+loaded = torch.load(buf, weights_only=False)
```

🤖 Prompt for AI Agents
+ loaded = torch.load(buf, weights_only=False)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit/torch/quantization/test_sequential_calibrate.py` around lines 984
- 987, Add an inline comment next to the torch.load call explaining why using
weights_only=False is safe: note that the buffer `buf` is produced locally by
`torch.save(progress, buf)` (not from external input) so deserializing with
`weights_only=False` does not violate the security guideline; annotate the
`loaded = torch.load(buf, weights_only=False)` line with this justification
referencing `buf`, `progress`, and `torch.load`.
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/checkpoint.py (1)
161-165: Consider using the constant for consistency. Line 161 uses the literal `_seq_calib_progress` while line 83 uses `getattr(model, SEQ_CALIB_PROGRESS_ATTR, ...)`. Using `setattr` with the constant would improve maintainability.

♻️ Suggested change

```diff
-    model._seq_calib_progress = {
+    setattr(model, SEQ_CALIB_PROGRESS_ATTR, {
         "completed_layer_idx": completed_layer_idx,
         "total_layers": total_layers,
         "layer_output_metas": layer_output_metas,
-    }
+    })
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/utils/checkpoint.py` around lines 161 - 165, Replace the literal attribute assignment model._seq_calib_progress with a setattr using the existing SEQ_CALIB_PROGRESS_ATTR constant to match the getattr usage elsewhere; specifically, set the attribute on model via setattr(model, SEQ_CALIB_PROGRESS_ATTR, {...}) using the same keys (completed_layer_idx, total_layers, layer_output_metas) so the code consistently references SEQ_CALIB_PROGRESS_ATTR instead of the hard-coded "_seq_calib_progress".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/utils/activation_collector.py`:
- Around line 452-456: The code flips preceding layers to "skip" via
_set_layer_mode before calling _validate_skip_metas, which can leave the
collector partially modified on validation failure; change the logic to validate
the skip metas before flipping modes or, if you must flip first, catch
exceptions from _validate_skip_metas and restore all layers in preceding back to
"original" using _set_layer_mode to avoid a half-resumed state; reference the
surrounding calls _run_warmup_capture, _set_layer_mode(preceding), and
_validate_skip_metas to locate where to add the pre-validation check or the
exception handler+restore.
- Around line 435-443: prepare_for_resume() uses resume_layer_idx without
validating it; add an upfront range check after confirming self._patched to
ensure 0 <= resume_layer_idx <= total_layers (use the class' layer count, e.g.,
len(self.layers) or self.num_layers) and raise a ValueError with a clear message
if out of range. Keep the existing resume_layer_idx == 0 early-return behavior
but perform the validation before any state mutation (before assigning k or
computing preceding) so negative indices or overly large indices are rejected
immediately.
---
Nitpick comments:
In `@modelopt/torch/quantization/utils/checkpoint.py`:
- Around line 161-165: Replace the literal attribute assignment
model._seq_calib_progress with a setattr using the existing
SEQ_CALIB_PROGRESS_ATTR constant to match the getattr usage elsewhere;
specifically, set the attribute on model via setattr(model,
SEQ_CALIB_PROGRESS_ATTR, {...}) using the same keys (completed_layer_idx,
total_layers, layer_output_metas) so the code consistently references
SEQ_CALIB_PROGRESS_ATTR instead of the hard-coded "_seq_calib_progress".
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dfffce23-8e9d-4e21-aaf3-1d51916f4c88
📒 Files selected for processing (4)

- modelopt/torch/quantization/conversion.py
- modelopt/torch/quantization/utils/activation_collector.py
- modelopt/torch/quantization/utils/checkpoint.py
- tests/unit/torch/quantization/test_sequential_calibrate.py
✅ Files skipped from review due to trivial changes (1)
- tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/conversion.py
```python
if not self._patched:
    raise RuntimeError(
        "prepare_for_resume() requires _patch_all_layers() to be called first."
    )
if resume_layer_idx == 0:
    return

k = resume_layer_idx
preceding = range(k - 1)
```
Validate resume_layer_idx before using it.
Negative values will target the last layers instead of being rejected, and values past the end will fail only after resume setup has already started mutating state. Since this value is checkpoint-derived, please range-check it up front.
💡 Suggested fix

```diff
     if not self._patched:
         raise RuntimeError(
             "prepare_for_resume() requires _patch_all_layers() to be called first."
         )
+    assert self._decoder_layers is not None
+    num_layers = len(self._decoder_layers)
+    if not 0 <= resume_layer_idx < num_layers:
+        raise ValueError(
+            f"resume_layer_idx must be in [0, {num_layers - 1}], got {resume_layer_idx}."
+        )
     if resume_layer_idx == 0:
         return
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/utils/activation_collector.py` around lines 435 -
443, prepare_for_resume() uses resume_layer_idx without validating it; add an
upfront range check after confirming self._patched to ensure 0 <=
resume_layer_idx <= total_layers (use the class' layer count, e.g.,
len(self.layers) or self.num_layers) and raise a ValueError with a clear message
if out of range. Keep the existing resume_layer_idx == 0 early-return behavior
but perform the validation before any state mutation (before assigning k or
computing preceding) so negative indices or overly large indices are rejected
immediately.
```python
self._run_warmup_capture(k - 1, forward_loop)

for i in preceding:
    self._set_layer_mode(i, "skip")
self._validate_skip_metas(preceding)
```
Rollback the mode flip if resume validation fails.
_validate_skip_metas() runs after the earlier layers have already been moved to skip, so a missing meta leaves the collector in a half-resumed state on the exception path. Please validate before the flip, or restore the touched layers to original when validation fails.
💡 Suggested fix

```diff
-        self._run_warmup_capture(k - 1, forward_loop)
-
-        for i in preceding:
-            self._set_layer_mode(i, "skip")
-        self._validate_skip_metas(preceding)
+        try:
+            self._run_warmup_capture(k - 1, forward_loop)
+            self._validate_skip_metas(preceding)
+        except Exception:
+            for i in preceding:
+                self._set_layer_mode(i, "original")
+            state = self._decoder_layers[k - 1]._seq_calib
+            state.mode = "original"
+            state.collected_inputs = []
+            raise
+
+        for i in preceding:
+            self._set_layer_mode(i, "skip")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/utils/activation_collector.py` around lines 452 -
456, The code flips preceding layers to "skip" via _set_layer_mode before
calling _validate_skip_metas, which can leave the collector partially modified
on validation failure; change the logic to validate the skip metas before
flipping modes or, if you must flip first, catch exceptions from
_validate_skip_metas and restore all layers in preceding back to "original"
using _set_layer_mode to avoid a half-resumed state; reference the surrounding
calls _run_warmup_capture, _set_layer_mode(preceding), and _validate_skip_metas
to locate where to add the pre-validation check or the exception
handler+restore.
What does this PR do?

Type of change: ?

Usage

```python
# Add a code snippet demonstrating how to use this
```

Testing

Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Bug Fixes / Behavior
Tests